from recbole.trainer import HyperTuning
import argparse
import logging
from logging import getLogger

import torch
import pickle
from recbole_cdr.config import CDRConfig
from recbole.data import save_split_dataloaders, load_split_dataloaders
from recbole.utils import init_logger, get_model, get_trainer, init_seed, set_color
from recbole_cdr.config import CDRConfig
from recbole_cdr.data import create_dataset, data_preparation
from recbole_cdr.utils import get_model, get_trainer
import warnings
# Silence every warning for the whole process (all categories, all modules),
# keeping the tuning log readable.
warnings.filterwarnings(action='ignore')

# Extra parameters handed to HyperTuning as ``params_dict`` below.
# NOTE(review): RecBole's HyperTuning appears to build its search space from
# ``params_dict`` only when no ``params_file`` is supplied; this script always
# passes a params file, so confirm whether this dict has any effect.
parameter_dict = dict(neg_sampling=None)
def objective_function(config_dict=None, config_file_list=None, saved=True):
    r"""Run a single HyperTuning trial: build the data, train a model, evaluate it.

    Args:
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.

    Returns:
        dict: ``best_valid_score``, ``valid_score_bigger`` (read from the
        config's ``valid_metric_bigger``), ``best_valid_result`` and
        ``test_result``, in that insertion order. The test result is also
        printed to stdout.
    """
    cfg = CDRConfig(model=None, config_file_list=config_file_list, config_dict=config_dict)

    def reseed():
        # Called before data preparation and again before model construction,
        # presumably so model initialisation does not depend on how much
        # randomness the data pipeline consumed.
        init_seed(cfg['seed'], cfg['reproducibility'])

    reseed()
    # Only errors from loggers using the root handler are shown; this is a
    # no-op once the root logger already has handlers.
    logging.basicConfig(level=logging.ERROR)
    dataset = create_dataset(cfg)
    train_data, valid_data, test_data = data_preparation(cfg, dataset)
    reseed()

    model_cls = get_model(cfg['model'])
    model = model_cls(cfg, train_data.dataset).to(cfg['device'])
    trainer_cls = get_trainer(cfg['MODEL_TYPE'], cfg['model'])
    trainer = trainer_cls(cfg, model)

    best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=False, saved=saved)
    # Reload the best checkpoint for testing only when one was saved.
    test_result = trainer.evaluate(test_data, load_best_model=saved)
    print(test_result)

    return dict(
        best_valid_score=best_valid_score,
        valid_score_bigger=cfg['valid_metric_bigger'],
        best_valid_result=best_valid_result,
        test_result=test_result,
    )

# Command-line entry point: run a hyper-parameter search with RecBole's
# HyperTuning, calling ``objective_function`` for every trial, then export
# and print the best configuration and its result.
parser = argparse.ArgumentParser()
parser.add_argument('--yaml', default='hyper_fixed/Douban_Movie_Book.yaml', type=str, help='yaml name')
parser.add_argument('--yaml_hyper_test', default='hyper_flexible/unicdr_best_ori.test', type=str, help='yaml list name')
parser.add_argument('--output_file', default='Experiments_results/a.result', type=str, help='file to write the tuning results to')
# parse_known_args: unrecognised command-line flags are ignored rather than
# causing argparse to exit with an error.
args, _ = parser.parse_known_args()
hp = HyperTuning(
    objective_function=objective_function,
    algo='exhaustive',
    # NOTE(review): with algo='exhaustive' RecBole appears to derive the number
    # of evaluations from the search-space size, so this value may be ignored
    # — confirm against the installed RecBole version.
    max_evals=100,
    params_file=args.yaml_hyper_test,  # search-space definition
    params_dict=parameter_dict,
    fixed_config_file_list=[args.yaml],  # non-searched config file(s)
)

hp.run()
# export result to the file
hp.export_result(output_file=args.output_file)
# print best parameters
print('best params: ', hp.best_params)
# print the result recorded for the best parameters (not the parameters again)
print('best result: ')
print(hp.params2result[hp.params2str(hp.best_params)])